Skip to content

feat(experimental): enable DTA training for Archon DP#1391

Open
ezoicoder wants to merge 1 commit into
areal-project:mainfrom
ezoicoder:feat/zero1-dta-archon-dp
Open

feat(experimental): enable DTA training for Archon DP#1391
ezoicoder wants to merge 1 commit into
areal-project:mainfrom
ezoicoder:feat/zero1-dta-archon-dp

Conversation

@ezoicoder
Copy link
Copy Markdown
Collaborator

@ezoicoder ezoicoder commented Jun 5, 2026

Description

Add a Dynamic Tree Attention path for Archon data-parallel training so shared-prefix rollout trajectories can be trained with block-wise backward while unsupported engines remain explicit.

Key changes:

  • Add tree_training_mode=dta and rollout-level DTA allocation config
  • Route Archon train and forward batches through the DTA wrapper
  • Add trie, allocation, runner, Zero1, and KV-cache model support
  • Report DTA allocation metrics during distributed rollout
  • Add examples, docs, and torchrun regression coverage against baseline DP

Related Issue

N/A

Type of Change

  • Bug fix
  • New feature
  • Breaking change
  • Documentation update
  • Refactoring
  • Performance improvement
  • Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated if applicable
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Breaking Change Details (if applicable):

N/A

Additional Context

Validation run:

source .venv/bin/activate && pre-commit run --all-files
uv run pytest tests/experimental/dta -q

DTA test result:

8 passed in 170.65s (0:02:50)

DTA test commands:

# Fast DTA unit tests; excludes tests marked slow.
uv run pytest tests/experimental/dta -m "not slow" -q

# Slow DTA tests only. These require CUDA and 2 GPUs.
uv run pytest tests/experimental/dta -m slow -q

# Full DTA test suite, including slow tests.
uv run pytest tests/experimental/dta -q

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces Dynamic Tree Attention (DTA) as a new tree training mode, replacing the boolean enable_tree_training flag with a multi-option tree_training_mode string. It adds the areal/experimental/dta module, integrates DTA into the Archon engine via a DTAWrapper, and updates attention mechanisms and Qwen2/Qwen3 models to support KV-cache attention. The review feedback highlights critical runtime AttributeError risks across the Qwen2/Qwen3 models and the DTA runner, where the code incorrectly assumes DynamicCache has a .layers attribute instead of using its standard key_cache and value_cache structures.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +399 to +401
past_len = 0
if len(past_key_values.layers) > 0:
past_len = int(past_key_values.layers[0].keys.shape[2])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The code assumes past_key_values has a .layers attribute. However, the standard transformers.cache_utils.DynamicCache class stores key/value states in key_cache and value_cache lists and does not have a .layers attribute. This will cause an AttributeError at runtime. Use get_seq_length() instead.

Suggested change
past_len = 0
if len(past_key_values.layers) > 0:
past_len = int(past_key_values.layers[0].keys.shape[2])
past_len = past_key_values.get_seq_length()

Comment on lines +431 to +433
if past_key_values is not None and layer_idx < len(past_key_values.layers):
layer_entry = past_key_values.layers[layer_idx]
layer_past = (layer_entry.keys, layer_entry.values)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The code assumes past_key_values has a .layers attribute. Standard DynamicCache uses key_cache and value_cache lists. Accessing .layers will cause an AttributeError at runtime.

Suggested change
if past_key_values is not None and layer_idx < len(past_key_values.layers):
layer_entry = past_key_values.layers[layer_idx]
layer_past = (layer_entry.keys, layer_entry.values)
if past_key_values is not None and layer_idx < len(past_key_values):
layer_past = (past_key_values.key_cache[layer_idx], past_key_values.value_cache[layer_idx])

Comment on lines +513 to +515
past_len = 0
if len(past_key_values.layers) > 0:
past_len = int(past_key_values.layers[0].keys.shape[2])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The code assumes past_key_values has a .layers attribute. However, the standard transformers.cache_utils.DynamicCache class stores key/value states in key_cache and value_cache lists and does not have a .layers attribute. This will cause an AttributeError at runtime. Use get_seq_length() instead.

Suggested change
past_len = 0
if len(past_key_values.layers) > 0:
past_len = int(past_key_values.layers[0].keys.shape[2])
past_len = past_key_values.get_seq_length()

Comment on lines +549 to +551
if past_key_values is not None and layer_idx < len(past_key_values.layers):
layer_entry = past_key_values.layers[layer_idx]
layer_past = (layer_entry.keys, layer_entry.values)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The code assumes past_key_values has a .layers attribute. Standard DynamicCache uses key_cache and value_cache lists. Accessing .layers will cause an AttributeError at runtime.

Suggested change
if past_key_values is not None and layer_idx < len(past_key_values.layers):
layer_entry = past_key_values.layers[layer_idx]
layer_past = (layer_entry.keys, layer_entry.values)
if past_key_values is not None and layer_idx < len(past_key_values):
layer_past = (past_key_values.key_cache[layer_idx], past_key_values.value_cache[layer_idx])

Comment on lines +263 to +270
new_cache = out.past_key_values
for layer_idx, layer in enumerate(new_cache.layers):
self.kv_cache[0][layer_idx][:, :, start:end, :] = layer.keys[
:, :, start:end, :
]
self.kv_cache[1][layer_idx][:, :, start:end, :] = layer.values[
:, :, start:end, :
]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The code assumes out.past_key_values has a .layers attribute. However, standard DynamicCache uses key_cache and value_cache lists. Accessing .layers will cause an AttributeError at runtime.

Suggested change
new_cache = out.past_key_values
for layer_idx, layer in enumerate(new_cache.layers):
self.kv_cache[0][layer_idx][:, :, start:end, :] = layer.keys[
:, :, start:end, :
]
self.kv_cache[1][layer_idx][:, :, start:end, :] = layer.values[
:, :, start:end, :
]
new_cache = out.past_key_values
for layer_idx in range(len(new_cache)):
self.kv_cache[0][layer_idx][:, :, start:end, :] = new_cache.key_cache[layer_idx][
:, :, start:end, :
]
self.kv_cache[1][layer_idx][:, :, start:end, :] = new_cache.value_cache[layer_idx][
:, :, start:end, :
]

Comment on lines +550 to +559
for layer_idx, layer in enumerate(block_cache.layers):
k = layer.keys[:, :, start:end, :]
v = layer.values[:, :, start:end, :]
roots.extend([k, v])
grads.extend(
[
self.grad_kv[0][layer_idx][:, :, start:end, :],
self.grad_kv[1][layer_idx][:, :, start:end, :],
]
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The code assumes block_cache has a .layers attribute. However, standard DynamicCache uses key_cache and value_cache lists. Accessing .layers will cause an AttributeError at runtime.

Suggested change
for layer_idx, layer in enumerate(block_cache.layers):
k = layer.keys[:, :, start:end, :]
v = layer.values[:, :, start:end, :]
roots.extend([k, v])
grads.extend(
[
self.grad_kv[0][layer_idx][:, :, start:end, :],
self.grad_kv[1][layer_idx][:, :, start:end, :],
]
)
for layer_idx in range(len(block_cache)):
k = block_cache.key_cache[layer_idx][:, :, start:end, :]
v = block_cache.value_cache[layer_idx][:, :, start:end, :]
roots.extend([k, v])
grads.extend(
[
self.grad_kv[0][layer_idx][:, :, start:end, :],
self.grad_kv[1][layer_idx][:, :, start:end, :],
]
)

@ezoicoder ezoicoder force-pushed the feat/zero1-dta-archon-dp branch 4 times, most recently from 72acd4b to 98e179e Compare June 7, 2026 04:47
@ezoicoder ezoicoder changed the title refactor(experimental): consolidate DTA Archon integration feat(experimental): enable DTA training for Archon DP Jun 7, 2026
Add a Dynamic Tree Attention path for Archon data-parallel training so shared-prefix rollout trajectories can be trained with block-wise backward while keeping unsupported engines explicit.

Key changes:

- Add tree_training_mode=dta and rollout-level DTA allocation config

- Route Archon train and forward batches through the DTA wrapper

- Add trie, allocation, runner, Zero1, and KV-cache model support

- Report DTA allocation metrics during distributed rollout

- Add examples, docs, and torchrun regression coverage against baseline DP
@ezoicoder ezoicoder force-pushed the feat/zero1-dta-archon-dp branch from 98e179e to 35508ea Compare June 7, 2026 04:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant